#!/usr/bin/env bash
#
# Launch train/train_legacy.py through torchrun with 4 worker processes on
# this node, using port 1111 for the torchrun master.
#
# Usage:
#   Run from the directory that contains train/ (the script path below is
#   relative). Before running, replace the /your_model_path/,
#   /your_data_path/ and /your_checkpoint_path/ placeholders with real
#   locations.
#
# Notes:
#   * Every long option uses the underscore spelling. The hyphenated
#     spellings (--master-port, --remove-unused-columns) are only understood
#     by newer torchrun / argument-parser releases, whereas the underscore
#     forms are accepted by older and newer ones alike and match the rest of
#     the flags here.
#   * --save_steps equals --max_steps (200), so presumably a checkpoint is
#     written only at the final step -- confirm against the trainer in use.
#   * With 4 processes x per-device batch 3 x 10 accumulation steps, one
#     optimizer step presumably covers 120 samples (assumes standard
#     data-parallel semantics -- TODO confirm in train_legacy.py).
#   * Comments cannot be placed between the continuation lines below: a
#     comment after a trailing backslash would end the command early.
torchrun --master_port 1111 --nproc_per_node=4 train/train_legacy.py \
    --model_name_or_path /your_model_path/Yarn-Llama-2-7b-128k \
    --llama_type llama2 \
    --data_path /your_data_path/llama2_pg19_8k_data \
    --output_dir /your_checkpoint_path/adapter_ckpts_llama2 \
    --max_steps 200 \
    --per_device_train_batch_size 3 \
    --gradient_accumulation_steps 10 \
    --save_steps 200 \
    --learning_rate 5e-3 \
    --weight_decay 0.1 \
    --warmup_steps 50 \
    --lr_scheduler_type cosine \
    --logging_steps 5 \
    --report_to tensorboard \
    --bf16 True \
    --medusa_heads 3 \
    --remove_unused_columns false
